{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp meta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "from fastcore.imports import *\n",
    "from fastcore.test import *\n",
    "from contextlib import contextmanager\n",
    "from copy import copy\n",
    "import inspect"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastcore.foundation import *\n",
    "from nbdev.showdoc import *\n",
    "from fastcore.nb_imports import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Meta\n",
    "\n",
    "> Metaclasses"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See this [blog post](https://realpython.com/python-metaclasses/) for more information about metaclasses. \n",
    "\n",
    "- `FixSigMeta` preserves information that enables [intropsection of signatures](https://www.python.org/dev/peps/pep-0362/#:~:text=Python%20has%20always%20supported%20powerful,fully%20reconstruct%20the%20function's%20signature.) (i.e. tab completion in IDEs) when certain types of inheritence would otherwise obfuscate this introspection.\n",
    "- `PrePostInitMeta` ensures that the classes defined with it run `__pre_init__` and `__post_init__` (without having to write `self.__pre_init__()` and `self.__post_init__()`  in the actual `init`\n",
    "- `NewChkMeta` gives the `PrePostInitMeta` functionality and ensures classes defined with it don't re-create an object of their type whenever it's passed to the constructor\n",
    "- `BypassNewMeta` ensures classes defined with it can easily be casted form objects they subclass."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def test_sig(f, b):\n",
    "    \"Test the signature of an object\"\n",
    "    test_eq(str(inspect.signature(f)), b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def func_1(h,i,j): pass\n",
    "def func_2(h,i=3, j=[5,6]): pass\n",
    "\n",
    "class T:\n",
    "    def __init__(self, a, b): pass\n",
    "\n",
    "test_sig(func_1, '(h, i, j)')\n",
    "test_sig(func_2, '(h, i=3, j=[5, 6])')\n",
    "test_sig(T, '(a, b)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export \n",
    "def _rm_self(sig):\n",
    "    sigd = dict(sig.parameters)\n",
    "    sigd.pop('self')\n",
    "    return sig.replace(parameters=sigd.values())"
   ]
  },
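  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why `_rm_self` is needed, here is a small standard-library sketch (the class `_Point` is a hypothetical example, not part of this module). Looked up on the class, `__init__` is a plain function whose signature starts with `self`, but callers never pass `self` when constructing an object, so it has to be dropped before the signature can describe the class:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect\n",
    "\n",
    "class _Point:\n",
    "    def __init__(self, x, y=0): self.x, self.y = x, y\n",
    "\n",
    "# On the class, __init__ is a plain function, so its signature includes `self`\n",
    "init_sig = inspect.signature(_Point.__init__)\n",
    "print(init_sig)  # (self, x, y=0)\n",
    "\n",
    "# Drop `self`, following the same steps as `_rm_self`\n",
    "params = dict(init_sig.parameters)\n",
    "params.pop('self')\n",
    "print(init_sig.replace(parameters=params.values()))  # (x, y=0)"
   ]
  },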
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class FixSigMeta(type):\n",
    "    \"A metaclass that fixes the signature on classes that override `__new__`\"\n",
    "    def __new__(cls, name, bases, dict):\n",
    "        res = super().__new__(cls, name, bases, dict)\n",
    "        if res.__init__ is not object.__init__: res.__signature__ = _rm_self(inspect.signature(res.__init__))\n",
    "        return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "### FixSigMeta\n",
       "\n",
       ">      FixSigMeta (name, bases, dict)\n",
       "\n",
       "A metaclass that fixes the signature on classes that override `__new__`"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "### FixSigMeta\n",
       "\n",
       ">      FixSigMeta (name, bases, dict)\n",
       "\n",
       "A metaclass that fixes the signature on classes that override `__new__`"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(FixSigMeta, title_level=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When you inherit from a class that defines `__new__`, or a metaclass that defines `__call__`, the signature of your `__init__` method is obfuscated such that tab completion no longer works.  `FixSigMeta` fixes this issue and restores signatures.\n",
    "\n",
    "To understand what `FixSigMeta` does, it is useful to inspect an object's signature.  You can inspect the signature of an object with `inspect.signature`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Signature (a, b, c)>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class T:\n",
    "    def __init__(self, a, b, c): pass\n",
    "    \n",
    "inspect.signature(T)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This corresponds to tab completion working in the normal way:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img alt=\"Tab completion in a Jupyter Notebook.\" caption=\"\" src=\"images/att_00005.png\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, when you inherhit from a class that defines `__new__` or a metaclass that defines `__call__` this obfuscates the signature by overriding your class with the signature of `__new__`, which prevents tab completion from displaying useful information:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Signature (d, e, f)>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Foo:\n",
    "    def __new__(self, **args): pass\n",
    "\n",
    "class Bar(Foo):\n",
    "    def __init__(self, d, e, f): pass\n",
    "    \n",
    "inspect.signature(Bar)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img alt=\"Tab completion in a Jupyter Notebook.\" caption=\"\" src=\"images/att_00006.png\">\n",
    "\n",
    "Finally, the signature and tab completion can be restored by inheriting from the metaclass `FixSigMeta` as shown below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Signature (d, e, f)>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Bar(Foo, metaclass=FixSigMeta):\n",
    "    def __init__(self, d, e, f): pass\n",
    "    \n",
    "test_sig(Bar, '(d, e, f)')\n",
    "inspect.signature(Bar)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img alt=\"Tab completion in a Jupyter Notebook.\" caption=\"\" src=\"images/att_00007.png\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you need to define a metaclass that overrides `__call__` (as done in `PrePostInitMeta`), you need to inherit from `FixSigMeta` instead of `type` when constructing the metaclass to preserve the signature in `__init__`.  Be careful not to override `__new__` when doing this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TestMeta(FixSigMeta):\n",
    "    # __new__ comes from FixSigMeta\n",
    "    def __call__(cls, *args, **kwargs): pass\n",
    "    \n",
    "class T(metaclass=TestMeta):\n",
    "    def __init__(self, a, b): pass\n",
    "    \n",
    "test_sig(T, '(a, b)')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On the other hand, if you fail to inherit from `FixSigMeta` when inheriting from a metaclass that overrides `__call__`, your signature will reflect that of `__call__` instead (which is often undesirable):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GenericMeta(type):\n",
    "    \"A boilerplate metaclass that doesn't do anything for testing.\"\n",
    "    def __new__(cls, name, bases, dict):\n",
    "        return super().__new__(cls, name, bases, dict)\n",
    "    def __call__(cls, *args, **kwargs): pass\n",
    "\n",
    "class T2(metaclass=GenericMeta):\n",
    "    def __init__(self, a, b): pass\n",
    "\n",
    "# We can avoid this by inheriting from the metaclass `FixSigMeta`\n",
    "test_sig(T2, '(*args, **kwargs)')"
   ]
  },
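  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what `FixSigMeta` changes, here is the same fix done by hand in plain Python (a sketch; `_CallMeta` and `_Plain` are hypothetical names). `inspect.signature` looks for a `__signature__` attribute before it looks at a metaclass's `__call__`, so storing the `__init__` signature, minus `self`, on the class restores it. `FixSigMeta.__new__` sets this attribute automatically whenever the class's `__init__` is something other than `object.__init__`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect\n",
    "\n",
    "class _CallMeta(type):\n",
    "    \"Hypothetical metaclass whose `__call__` hides the `__init__` signature\"\n",
    "    def __call__(cls, *args, **kwargs):\n",
    "        return super().__call__(*args, **kwargs)\n",
    "\n",
    "class _Plain(metaclass=_CallMeta):\n",
    "    def __init__(self, a, b=2): self.a, self.b = a, b\n",
    "\n",
    "print(inspect.signature(_Plain))  # (*args, **kwargs)\n",
    "\n",
    "# Do by hand what FixSigMeta automates: store the __init__ signature, minus `self`,\n",
    "# on the class; inspect.signature checks `__signature__` before the metaclass `__call__`\n",
    "init_sig = inspect.signature(_Plain.__init__)\n",
    "_Plain.__signature__ = init_sig.replace(parameters=[p for n, p in init_sig.parameters.items() if n != 'self'])\n",
    "print(inspect.signature(_Plain))  # (a, b=2)"
   ]
  },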
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class PrePostInitMeta(FixSigMeta):\n",
    "    \"A metaclass that calls optional `__pre_init__` and `__post_init__` methods\"\n",
    "    def __call__(cls, *args, **kwargs):\n",
    "        res = cls.__new__(cls)\n",
    "        if type(res)==cls:\n",
    "            if hasattr(res,'__pre_init__'): res.__pre_init__(*args,**kwargs)\n",
    "            res.__init__(*args,**kwargs)\n",
    "            if hasattr(res,'__post_init__'): res.__post_init__(*args,**kwargs)\n",
    "        return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "### PrePostInitMeta\n",
       "\n",
       ">      PrePostInitMeta (name, bases, dict)\n",
       "\n",
       "A metaclass that calls optional `__pre_init__` and `__post_init__` methods"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "### PrePostInitMeta\n",
       "\n",
       ">      PrePostInitMeta (name, bases, dict)\n",
       "\n",
       "A metaclass that calls optional `__pre_init__` and `__post_init__` methods"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(PrePostInitMeta, title_level=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`__pre_init__`  and `__post_init__` are useful for initializing variables or performing tasks prior to or after `__init__` being called, respectively.  Fore example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _T(metaclass=PrePostInitMeta):\n",
    "    def __pre_init__(self):  self.a  = 0; \n",
    "    def __init__(self,b=0):  self.b = self.a + 1; assert self.b==1\n",
    "    def __post_init__(self): self.c = self.b + 2; assert self.c==3\n",
    "\n",
    "t = _T()\n",
    "test_eq(t.a, 0) # set with __pre_init__\n",
    "test_eq(t.b, 1) # set with __init__\n",
    "test_eq(t.c, 3) # set with __post_init__"
   ]
  },
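  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that `PrePostInitMeta.__call__` passes the constructor's arguments to all three methods. `__pre_init__` in `_T` above takes no arguments, which only works because `_T()` is called without any; hooks that accept `*args, **kwargs` are more robust. The sketch below uses a hypothetical stand-in metaclass, `_HookMeta`, that follows the same call order so it runs on its own (the class `_Range` is hypothetical too), and uses `__post_init__` to validate once every attribute has been set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _HookMeta(type):\n",
    "    \"Hypothetical stand-in with the same call order as `PrePostInitMeta`\"\n",
    "    def __call__(cls, *args, **kwargs):\n",
    "        res = cls.__new__(cls)\n",
    "        if hasattr(res, '__pre_init__'): res.__pre_init__(*args, **kwargs)\n",
    "        res.__init__(*args, **kwargs)\n",
    "        if hasattr(res, '__post_init__'): res.__post_init__(*args, **kwargs)\n",
    "        return res\n",
    "\n",
    "class _Range(metaclass=_HookMeta):\n",
    "    # Each hook receives the same arguments as __init__, so the hooks accept *args, **kwargs\n",
    "    def __pre_init__(self, *args, **kwargs): self.calls = ['pre']\n",
    "    def __init__(self, lo, hi=10): self.lo, self.hi = lo, hi; self.calls.append('init')\n",
    "    def __post_init__(self, *args, **kwargs):\n",
    "        self.calls.append('post')\n",
    "        # Validate once __init__ has set every attribute\n",
    "        if self.lo > self.hi: raise ValueError(f'lo={self.lo} is greater than hi={self.hi}')\n",
    "\n",
    "r = _Range(2, hi=5)\n",
    "print(r.calls)  # ['pre', 'init', 'post']"
   ]
  },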
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One use for `PrePostInitMeta` is avoiding the `__super__().__init__()` boilerplate associated with subclassing, such as used in `AutoInit`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class AutoInit(metaclass=PrePostInitMeta):\n",
    "    \"Same as `object`, but no need for subclasses to call `super().__init__`\"\n",
    "    def __pre_init__(self, *args, **kwargs): super().__init__(*args, **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is normally used as a [mixin](https://www.residentmar.io/2019/07/07/python-mixins.html), eg:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TestParent():\n",
    "    def __init__(self): self.h = 10\n",
    "        \n",
    "class TestChild(AutoInit, TestParent):\n",
    "    def __init__(self): self.k = self.h + 2\n",
    "    \n",
    "t = TestChild()\n",
    "test_eq(t.h, 10) # h=10 is initialized in the parent class\n",
    "test_eq(t.k, 12)"
   ]
  },
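  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, here is a plain-Python sketch (hypothetical classes, no metaclass) of the boilerplate `AutoInit` removes. A subclass that overrides `__init__` without calling `super().__init__()` silently skips the parent's initialization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _Parent:\n",
    "    def __init__(self): self.h = 10\n",
    "\n",
    "class _ForgetfulChild(_Parent):\n",
    "    # Overrides __init__ without calling super().__init__(), so _Parent.__init__ never runs\n",
    "    def __init__(self): self.k = 2\n",
    "\n",
    "class _ExplicitChild(_Parent):\n",
    "    def __init__(self):\n",
    "        super().__init__()  # the line that AutoInit makes unnecessary\n",
    "        self.k = self.h + 2\n",
    "\n",
    "print(hasattr(_ForgetfulChild(), 'h'))  # False\n",
    "print(_ExplicitChild().k)               # 12"
   ]
  },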
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class NewChkMeta(FixSigMeta):\n",
    "    \"Metaclass to avoid recreating object passed to constructor\"\n",
    "    def __call__(cls, x=None, *args, **kwargs):\n",
    "        if not args and not kwargs and x is not None and isinstance(x,cls): return x\n",
    "        res = super().__call__(*((x,) + args), **kwargs)\n",
    "        return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "### NewChkMeta\n",
       "\n",
       ">      NewChkMeta (name, bases, dict)\n",
       "\n",
       "Metaclass to avoid recreating object passed to constructor"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "### NewChkMeta\n",
       "\n",
       ">      NewChkMeta (name, bases, dict)\n",
       "\n",
       "Metaclass to avoid recreating object passed to constructor"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(NewChkMeta, title_level=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`NewChkMeta` is used when an object of the same type is the first argument to your class's constructor (i.e. the `__init__` function), and you would rather it not create a new object but point to the same exact object.  \n",
    "\n",
    "This is used in `L`, for example, to avoid creating a new object when the object is already of type `L`.  This allows the users to defenisvely instantiate an `L` object and just return a reference to the same object if it already happens to be of type `L`.\n",
    "\n",
    "For example, the below class `_T` **optionally** accepts an object `o` as its first argument.  A new object is returned upon instantiation per usual:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _T():\n",
    "    \"Testing\"\n",
    "    def __init__(self, o): \n",
    "        # if `o` is not an object without an attribute `foo`, set foo = 1\n",
    "        self.foo = getattr(o,'foo',1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = _T(3)\n",
    "test_eq(t.foo,1) # 1 was not of type _T, so foo = 1\n",
    "\n",
    "t2 = _T(t) #t1 is of type _T\n",
    "assert t is not t2 # t1 and t2 are different objects"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, if we want `_T` to return a reference to the same object when passed an an object of type `_T` we can inherit from the `NewChkMeta` class as illustrated below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _T(metaclass=NewChkMeta):\n",
    "    \"Testing with metaclass NewChkMeta\"\n",
    "    def __init__(self, o=None, b=1):\n",
    "        # if `o` is not an object without an attribute `foo`, set foo = 1\n",
    "        self.foo = getattr(o,'foo',1)\n",
    "        self.b = b"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now test `t` and `t2` are now pointing at the same object when using this new definition of `_T`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = _T(3)\n",
    "test_eq(t.foo,1) # 1 was not of type _T, so foo = 1\n",
    "\n",
    "t2 = _T(t) # t2 will now reference t\n",
    "\n",
    "test_is(t, t2) # t and t2 are the same object\n",
    "t2.foo = 5 # this will also change t.foo to 5 because it is the same object\n",
    "test_eq(t.foo, 5)\n",
    "test_eq(t2.foo, 5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, there is one exception to how `NewChkMeta` works.  **If you pass any additional arguments in the constructor a new object is returned**, even if the first object is of the same type.  For example, consider the below example where we pass the additional argument `b` into the constructor:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t3 = _T(t, b=1)\n",
    "assert t3 is not t\n",
    "\n",
    "t4 = _T(t) # without any arguments the constructor will return a reference to the same object\n",
    "assert t4 is t"
   ]
  },
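  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The check `NewChkMeta.__call__` performs is short enough to sketch on its own. The version below (the metaclass `_NoRecreateMeta`, the class `_Box` and the function `_total` are all hypothetical) reproduces the same condition, including the rule that extra arguments force a new object, and shows the defensive pattern described above: a function can wrap its input unconditionally without copying inputs that are already of the right type. Since the sketch inherits from `type` rather than `FixSigMeta`, the signature of `_Box` is not preserved:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _NoRecreateMeta(type):\n",
    "    \"Hypothetical minimal metaclass reproducing the `NewChkMeta` check\"\n",
    "    def __call__(cls, x=None, *args, **kwargs):\n",
    "        # Hand back `x` itself if it is already an instance and nothing else was passed\n",
    "        if not args and not kwargs and x is not None and isinstance(x, cls): return x\n",
    "        return super().__call__(x, *args, **kwargs)\n",
    "\n",
    "class _Box(metaclass=_NoRecreateMeta):\n",
    "    def __init__(self, items=None, tag=None):\n",
    "        # Copy the items of another _Box, or of any other iterable\n",
    "        self.items = list(items.items if isinstance(items, _Box) else (items or []))\n",
    "        self.tag = tag\n",
    "\n",
    "def _total(xs):\n",
    "    \"Wrap `xs` defensively: no copy is made if it is already a `_Box`\"\n",
    "    return sum(_Box(xs).items)\n",
    "\n",
    "box = _Box([1, 2, 3])\n",
    "print(_Box(box) is box)              # True: the same object is returned\n",
    "print(_Box(box, tag='copy') is box)  # False: an extra argument builds a new object\n",
    "print(_total(box), _total([1, 2, 3]))  # 6 6"
   ]
  },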
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, it should be noted that `NewChkMeta` as well as all other metaclases in this section, inherit from `FixSigMeta`.  This means class signatures will always be preserved when inheriting from this metaclass (see docs for `FixSigMeta` for more details):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_sig(_T, '(o=None, b=1)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class BypassNewMeta(FixSigMeta):\n",
    "    \"Metaclass: casts `x` to this class if it's of type `cls._bypass_type`\"\n",
    "    def __call__(cls, x=None, *args, **kwargs):\n",
    "        if hasattr(cls, '_new_meta'): x = cls._new_meta(x, *args, **kwargs)\n",
    "        elif not isinstance(x,getattr(cls,'_bypass_type',object)) or len(args) or len(kwargs):\n",
    "            x = super().__call__(*((x,)+args), **kwargs)\n",
    "        if cls!=x.__class__: x.__class__ = cls\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "### BypassNewMeta\n",
       "\n",
       ">      BypassNewMeta (name, bases, dict)\n",
       "\n",
       "Metaclass: casts `x` to this class if it's of type `cls._bypass_type`"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "### BypassNewMeta\n",
       "\n",
       ">      BypassNewMeta (name, bases, dict)\n",
       "\n",
       "Metaclass: casts `x` to this class if it's of type `cls._bypass_type`"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(BypassNewMeta, title_level=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`BypassNewMeta` is identical to `NewChkMeta`, except for checking for a class as the same type, we instead check for a class of type specified in attribute `_bypass_type`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In NewChkMeta, objects of the same type passed to the constructor (without arguments) would result into a new variable referencing the same object.  However, with `BypassNewMeta` this only occurs if the type matches the `_bypass_type` of the class you are defining:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _TestA: pass\n",
    "class _TestB: pass\n",
    "\n",
    "class _T(_TestA, metaclass=BypassNewMeta):\n",
    "    _bypass_type=_TestB\n",
    "    def __init__(self,x): self.x=x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the below example, `t` does not refer to `t2` because `t` is of type `_TestA` while `_T._bypass_type` is of type `TestB`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = _TestA()\n",
    "t2 = _T(t)\n",
    "assert t is not t2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, if `t` is set to `_TestB` to match `_T._bypass_type`, then both `t` and `t2` will refer to the same object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = _TestB()\n",
    "t2 = _T(t)\n",
    "t2.new_attr = 15\n",
    "\n",
    "test_is(t, t2)\n",
    "# since t2 just references t these will be the same\n",
    "test_eq(t.new_attr, t2.new_attr)\n",
    "\n",
    "# likewise, chaning an attribute on t will also affect t2 because they both point to the same object.\n",
    "t.new_attr = 9\n",
    "test_eq(t2.new_attr, 9)"
   ]
  },
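  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Under the hood, when `x` is an instance of `_bypass_type` and no other arguments are passed, `BypassNewMeta` never calls `super().__call__`, so `__init__` is not run; it simply reassigns `x.__class__`. In the example above this means `t` itself now has type `_T`, and `t2.x` is never set. The standard-library sketch below (with hypothetical classes `_Source` and `_Target`) isolates that class reassignment. Python generally only allows it between classes with compatible instance layouts, such as ordinary user-defined classes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _Source: pass\n",
    "\n",
    "class _Target:\n",
    "    \"Hypothetical class we want to view existing `_Source` objects as\"\n",
    "    def __init__(self, x): self.x = x  # never called below\n",
    "    def double(self): return self.x * 2\n",
    "\n",
    "obj = _Source()\n",
    "obj.x = 21\n",
    "\n",
    "# What BypassNewMeta does for an instance of `_bypass_type`: reassign the class in place.\n",
    "# No new object is created and __init__ is not run, so existing attributes are kept.\n",
    "obj.__class__ = _Target\n",
    "print(type(obj).__name__, obj.double())  # _Target 42"
   ]
  },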
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Metaprogramming"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def empty2none(p):\n",
    "    \"Replace `Parameter.empty` with `None`\"\n",
    "    return None if p==inspect.Parameter.empty else p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def anno_dict(f):\n",
    "    \"`__annotation__ dictionary with `empty` cast to `None`, returning empty if doesn't exist\"\n",
    "    return {k:empty2none(v) for k,v in getattr(f, '__annotations__', {}).items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _f(a:int, b:L)->str: ...\n",
    "test_eq(anno_dict(_f), {'a': int, 'b': L, 'return': str})"
   ]
  },
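  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`anno_dict` reads `__annotations__`, which only contains the names that actually have annotations, so unannotated parameters are missing rather than mapped to `None`. If you need an entry for every parameter, a sketch along the following lines (the function `_g` is a hypothetical example) uses `inspect.signature` instead and applies the same substitution as `empty2none`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect\n",
    "\n",
    "def _g(a, b: int = 1) -> str: ...\n",
    "\n",
    "# __annotations__ only lists annotated names, so `a` is missing entirely\n",
    "print(_g.__annotations__)  # {'b': <class 'int'>, 'return': <class 'str'>}\n",
    "\n",
    "# inspect.signature covers every parameter and reports Parameter.empty where there is\n",
    "# no annotation; swapping empty for None (as `empty2none` does) gives a complete mapping\n",
    "annos = {n: (None if p.annotation is inspect.Parameter.empty else p.annotation)\n",
    "         for n, p in inspect.signature(_g).parameters.items()}\n",
    "print(annos)  # {'a': None, 'b': <class 'int'>}"
   ]
  },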
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def _mk_param(n,d=None): return inspect.Parameter(n, inspect.Parameter.KEYWORD_ONLY, default=d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def use_kwargs_dict(keep=False, **kwargs):\n",
    "    \"Decorator: replace `**kwargs` in signature with `names` params\"\n",
    "    def _f(f):\n",
    "        sig = inspect.signature(f)\n",
    "        sigd = dict(sig.parameters)\n",
    "        k = sigd.pop('kwargs')\n",
    "        s2 = {n:_mk_param(n,d) for n,d in kwargs.items() if n not in sigd}\n",
    "        sigd.update(s2)\n",
    "        if keep: sigd['kwargs'] = k\n",
    "        f.__signature__ = sig.replace(parameters=sigd.values())\n",
    "        return f\n",
    "    return _f"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Replace all `**kwargs` with named arguments like so:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@use_kwargs_dict(y=1,z=None)\n",
    "def foo(a, b=1, **kwargs): pass\n",
    "\n",
    "test_sig(foo, '(a, b=1, *, y=1, z=None)')"
   ]
  },
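  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`use_kwargs_dict` only changes the function's `__signature__`; the function itself still receives these values through `**kwargs`. A minimal sketch of the same idea with plain `inspect` (the function `_h` is hypothetical) shows two consequences: the advertised default is not filled in when the argument is omitted, and keyword arguments that are not in the new signature are still accepted:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect\n",
    "\n",
    "def _h(a, b=1, **kwargs): return kwargs\n",
    "\n",
    "# Replace **kwargs with a keyword-only `y=1`, much as use_kwargs_dict(y=1) does\n",
    "sig = inspect.signature(_h)\n",
    "params = [p for p in sig.parameters.values() if p.kind is not inspect.Parameter.VAR_KEYWORD]\n",
    "params.append(inspect.Parameter('y', inspect.Parameter.KEYWORD_ONLY, default=1))\n",
    "_h.__signature__ = sig.replace(parameters=params)\n",
    "\n",
    "print(inspect.signature(_h))  # (a, b=1, *, y=1)\n",
    "print(_h(0, y=5))             # {'y': 5}\n",
    "print(_h(0))                  # {} since the advertised default is not applied at call time\n",
    "print(_h(0, w=3))             # {'w': 3} since unlisted keywords are still accepted"
   ]
  },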
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Add named arguments, but optionally keep `**kwargs` by setting `keep=True`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@use_kwargs_dict(y=1,z=None, keep=True)\n",
    "def foo(a, b=1, **kwargs): pass\n",
    "\n",
    "test_sig(foo, '(a, b=1, *, y=1, z=None, **kwargs)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def use_kwargs(names, keep=False):\n",
    "    \"Decorator: replace `**kwargs` in signature with `names` params\"\n",
    "    def _f(f):\n",
    "        sig = inspect.signature(f)\n",
    "        sigd = dict(sig.parameters)\n",
    "        k = sigd.pop('kwargs')\n",
    "        s2 = {n:_mk_param(n) for n in names if n not in sigd}\n",
    "        sigd.update(s2)\n",
    "        if keep: sigd['kwargs'] = k\n",
    "        f.__signature__ = sig.replace(parameters=sigd.values())\n",
    "        return f\n",
    "    return _f"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`use_kwargs` is different than `use_kwargs_dict` as it only replaces `**kwargs` with named parameters without any default values:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@use_kwargs(['y', 'z'])\n",
    "def foo(a, b=1, **kwargs): pass\n",
    "\n",
    "test_sig(foo, '(a, b=1, *, y=None, z=None)')"
   ]
  },
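  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the new parameters are only advertised in `__signature__`, `foo` above still receives nothing for `y` or `z` unless the caller passes them. If you want the advertised defaults applied at call time, one option is to bind the arguments against the signature with `inspect.Signature.bind` and `BoundArguments.apply_defaults`. The helper `_call_with_defaults` and the function `_k` below are hypothetical, a sketch rather than part of this module:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect\n",
    "\n",
    "def _k(a, b=1, **kwargs): return kwargs\n",
    "\n",
    "# Advertise keyword-only `y` and `z` defaulting to None, as use_kwargs(['y', 'z']) does\n",
    "_P = inspect.Parameter\n",
    "_k.__signature__ = inspect.Signature([\n",
    "    _P('a', _P.POSITIONAL_OR_KEYWORD), _P('b', _P.POSITIONAL_OR_KEYWORD, default=1),\n",
    "    _P('y', _P.KEYWORD_ONLY, default=None), _P('z', _P.KEYWORD_ONLY, default=None)])\n",
    "\n",
    "def _call_with_defaults(f, *args, **kwargs):\n",
    "    \"Hypothetical helper: bind against the advertised signature, then fill in its defaults\"\n",
    "    bound = inspect.signature(f).bind(*args, **kwargs)  # raises TypeError for unknown names\n",
    "    bound.apply_defaults()\n",
    "    return f(*bound.args, **bound.kwargs)\n",
    "\n",
    "print(_k(0))                            # {}\n",
    "print(_call_with_defaults(_k, 0))       # {'y': None, 'z': None}\n",
    "print(_call_with_defaults(_k, 0, z=3))  # {'y': None, 'z': 3}"
   ]
  },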
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You may optionally keep the `**kwargs` argument in your signature by setting `keep=True`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@use_kwargs(['y', 'z'], keep=True)\n",
    "def foo(a, *args, b=1, **kwargs): pass\n",
    "test_sig(foo, '(a, *args, b=1, y=None, z=None, **kwargs)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def delegates(to:FunctionType=None, # Delegatee\n",
    "              keep=False, # Keep `kwargs` in decorated function?\n",
    "              but:list=None): # Exclude these parameters from signature\n",
    "    \"Decorator: replace `**kwargs` in signature with params from `to`\"\n",
    "    if but is None: but = []\n",
    "    def _f(f):\n",
    "        if to is None: to_f,from_f = f.__base__.__init__,f.__init__\n",
    "        else:          to_f,from_f = to.__init__ if isinstance(to,type) else to,f\n",
    "        from_f = getattr(from_f,'__func__',from_f)\n",
    "        to_f = getattr(to_f,'__func__',to_f)\n",
    "        if hasattr(from_f,'__delwrap__'): return f\n",
    "        sig = inspect.signature(from_f)\n",
    "        sigd = dict(sig.parameters)\n",
    "        k = sigd.pop('kwargs')\n",
    "        s2 = {k:v.replace(kind=inspect.Parameter.KEYWORD_ONLY) for k,v in inspect.signature(to_f).parameters.items()\n",
    "              if v.default != inspect.Parameter.empty and k not in sigd and k not in but}\n",
    "        anno = {k:v for k,v in getattr(to_f, \"__annotations__\", {}).items() if k not in sigd and k not in but}\n",
    "        sigd.update(s2)\n",
    "        if keep: sigd['kwargs'] = k\n",
    "        else: from_f.__delwrap__ = to_f\n",
    "        from_f.__signature__ = sig.replace(parameters=sigd.values())\n",
    "        if hasattr(from_f, '__annotations__'): from_f.__annotations__.update(anno)\n",
    "        return f\n",
    "    return _f"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A common Python idiom is to accept `**kwargs` in addition to named parameters that are passed onto other function calls. It is especially common to use `**kwargs` when you want to give the user an option to override default parameters of any functions or methods being called by the parent function.\n",
    "\n",
    "For example, suppose we have have a function `foo` that passes arguments to `baz` like so:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def baz(a, b:int=2, c:int=3): return a + b + c\n",
    "\n",
    "def foo(c, a, **kwargs):\n",
    "    return c + baz(a, **kwargs)\n",
    "\n",
    "assert foo(c=1, a=1) == 7"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The problem with this approach is the api for `foo` is obfuscated. Users cannot introspect what the valid arguments for `**kwargs` are without reading the source code.  When a user tries tries to introspect the signature of `foo`, they are presented with this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Signature (c, a, **kwargs)>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inspect.signature(foo)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can address this issue by using the decorator `delegates` to include parameters from other functions.  For example, if we apply the `delegates` decorator to `foo` to include parameters from `baz`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Signature (c, a, *, b: int = 2)>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@delegates(baz)\n",
    "def foo(c, a, **kwargs):\n",
    "    return c + baz(a, **kwargs)\n",
    "\n",
    "test_sig(foo, '(c, a, *, b: int = 2)')\n",
    "inspect.signature(foo)"
   ]
  },
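  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The core of what `delegates` does for a plain function can be sketched in a few lines of `inspect`. `_delegates_sketch` below is a hypothetical, simplified version that skips the class, `keep` and `but` handling, annotation merging and the `__delwrap__` bookkeeping: it drops `**kwargs` and appends, as keyword-only parameters, every parameter of the delegatee that has a default and is not already in the signature:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect\n",
    "\n",
    "def _delegates_sketch(to):\n",
    "    \"Hypothetical, simplified take on `delegates` for plain functions\"\n",
    "    def _f(f):\n",
    "        sig = inspect.signature(f)\n",
    "        # f's own parameters, minus **kwargs\n",
    "        own = {n: p for n, p in sig.parameters.items() if p.kind is not inspect.Parameter.VAR_KEYWORD}\n",
    "        # parameters of `to` that have defaults and aren't already in f, made keyword-only\n",
    "        extra = [p.replace(kind=inspect.Parameter.KEYWORD_ONLY)\n",
    "                 for n, p in inspect.signature(to).parameters.items()\n",
    "                 if p.default is not inspect.Parameter.empty and n not in own]\n",
    "        f.__signature__ = sig.replace(parameters=[*own.values(), *extra])\n",
    "        return f\n",
    "    return _f\n",
    "\n",
    "def _baz(a, b: int = 2, c: int = 3): return a + b + c\n",
    "\n",
    "@_delegates_sketch(_baz)\n",
    "def _foo(c, a, **kwargs): return c + _baz(a, **kwargs)\n",
    "\n",
    "print(inspect.signature(_foo))  # (c, a, *, b: int = 2)\n",
    "print(_foo(1, 1, b=5))          # 10"
   ]
  },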
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can optionally decide to keep `**kwargs` by setting `keep=True`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Signature (c, a, *, b: int = 2, **kwargs)>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@delegates(baz, keep=True)\n",
    "def foo(c, a, **kwargs):\n",
    "    return c + baz(a, **kwargs)\n",
    "\n",
    "inspect.signature(foo)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is important to note that **only parameters with default parameters are included**.  For example, in the below scenario only `c`, but NOT `e` and `d` are included in the signature of `foo` after applying `delegates`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Signature (a, b=1, *, c=2)>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def basefoo(e, d, c=2): pass\n",
    "\n",
    "@delegates(basefoo)\n",
    "def foo(a, b=1, **kwargs): pass\n",
    "inspect.signature(foo) # e and d are not included b/c they don't have default parameters."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reason that required arguments (i.e. those without default parameters) are automatically excluded is that you should be explicitly implementing required arguments into your function's signature rather than relying on `delegates`.\n",
    "\n",
    "Additionally, you can exclude specific parameters from being included in the signature with the  `but` parameter.  In the example below, we exclude the parameter `d`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Signature (a, b=1, *, c=2)>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def basefoo(e, c=2, d=3): pass\n",
    "\n",
    "@delegates(basefoo, but= ['d'])\n",
    "def foo(a, b=1, **kwargs): pass\n",
    "\n",
    "test_sig(foo, '(a, b=1, *, c=2)')\n",
    "inspect.signature(foo)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also use `delegates` between methods in a class.  Here is an example of `delegates` with class methods:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# example 1: class methods\n",
    "class _T():\n",
    "    @classmethod\n",
    "    def foo(cls, a=1, b=2):\n",
    "        pass\n",
    "    \n",
    "    @classmethod\n",
    "    @delegates(foo)\n",
    "    def bar(cls, c=3, **kwargs):\n",
    "        pass\n",
    "\n",
    "test_sig(_T.bar, '(c=3, *, a=1, b=2)')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is the same example with instance methods:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# example 2: instance methods\n",
    "class _T():\n",
    "    def foo(self, a=1, b=2):\n",
    "        pass\n",
    "    \n",
    "    @delegates(foo)\n",
    "    def bar(self, c=3, **kwargs):\n",
    "        pass\n",
    "\n",
    "t = _T()\n",
    "test_sig(t.bar, '(c=3, *, a=1, b=2)')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also delegate between classes. By default, when applied to a class, the `delegates` decorator delegates the class's `__init__` to its superclass's `__init__`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BaseFoo:\n",
    "    def __init__(self, e, c=2): pass\n",
    "\n",
    "@delegates() # no argument is passed, so we delegate to the superclass\n",
    "class Foo(BaseFoo):\n",
    "    def __init__(self, a, b=1, **kwargs): super().__init__(**kwargs)\n",
    "\n",
    "test_sig(Foo, '(a, b=1, *, c=2)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def method(f):\n",
    "    \"Mark `f` as a method\"\n",
    "    # `1` is a dummy instance since Py3 doesn't allow `None` any more\n",
    "    return MethodType(f, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `method` function converts a function into a method by wrapping it in `MethodType`. In the example below, we change the type of `a` from a function to a method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def a(x=2): return x + 1\n",
    "assert type(a).__name__ == 'function'\n",
    "\n",
    "a = method(a)\n",
    "assert type(a).__name__ == 'method'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def _funcs_kwargs(cls, as_method):\n",
    "    old_init = cls.__init__\n",
    "    def _init(self, *args, **kwargs):\n",
    "        for k in cls._methods:\n",
    "            arg = kwargs.pop(k,None)\n",
    "            if arg is not None:\n",
    "                if as_method: arg = method(arg)\n",
    "                if isinstance(arg,MethodType): arg = MethodType(arg.__func__, self)\n",
    "                setattr(self, k, arg)\n",
    "        old_init(self, *args, **kwargs)\n",
    "    functools.update_wrapper(_init, old_init)\n",
    "    cls.__init__ = use_kwargs(cls._methods)(_init)\n",
    "    if hasattr(cls, '__signature__'): cls.__signature__ = _rm_self(inspect.signature(cls.__init__))\n",
    "    return cls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def funcs_kwargs(as_method=False):\n",
    "    \"Replace methods in `cls._methods` with those from `kwargs`\"\n",
    "    if callable(as_method): return _funcs_kwargs(as_method, False)\n",
    "    return partial(_funcs_kwargs, as_method=as_method)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `funcs_kwargs` decorator lets you supply functions or methods for a class as keyword arguments when it is instantiated. You must list their names in a class attribute named `_methods` when defining your class. Additionally, you must include the `**kwargs` argument in your class's `__init__` method.\n",
    "\n",
    "After defining your class this way, you can add functions to it upon instantiation, as illustrated below.\n",
    "\n",
    "For example, we define class `T` so that the function `b` can be supplied upon instantiation (note that a function passed this way is stored as an attribute of the instance and does not have access to `cls` or `self`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@funcs_kwargs\n",
    "class T:\n",
    "    _methods=['b'] # allows you to add method b upon instantiation\n",
    "    def __init__(self, f=1, **kwargs): pass # don't forget to include **kwargs in __init__\n",
    "    def a(self): return 1\n",
    "    def b(self): return 2\n",
    "    \n",
    "t = T()\n",
    "test_eq(t.a(), 1)\n",
    "test_eq(t.b(), 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because we defined `T` this way, its signature lists each name in `_methods` as a keyword-only argument defaulting to `None`. In this example, `b` is added to the signature:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Signature (f=1, *, b=None)>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_sig(T, '(f=1, *, b=None)')\n",
    "inspect.signature(T)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can now supply your own function `b` when instantiating `T`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _new_func(): return 5\n",
    "\n",
    "t = T(b=_new_func)\n",
    "test_eq(t.b(), 5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you pass a function whose name is not listed in `_methods`, it is ignored. In the example below, the attempt to replace `a` is ignored:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = T(a = lambda:3)\n",
    "test_eq(t.a(), 1) # `a` is not in `_methods`, so the original method is used"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that you can also add methods that are not defined in the original class, as long as they are listed in the `_methods` attribute:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@funcs_kwargs\n",
    "class T:\n",
    "    _methods=['c']\n",
    "    def __init__(self, f=1, **kwargs): pass\n",
    "\n",
    "t = T(c = lambda: 4)\n",
    "test_eq(t.c(), 4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Until now, these examples showed how to add functions that are stored as instance attributes and have no access to `self`. If you need access to `self`, set `as_method=True` in the `funcs_kwargs` decorator to add a method instead:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _f(self,a=1): return self.num + a # access the num attribute from the instance\n",
    "\n",
    "@funcs_kwargs(as_method=True)\n",
    "class T: \n",
    "    _methods=['b']\n",
    "    num = 5\n",
    "    \n",
    "t = T(b=_f) # adds method b\n",
    "test_eq(t.b(5), 10) # self.num + 5 = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is an example of how you might use this functionality with inheritance:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _f(self,a=1): return self.num * a # multiply instead of add\n",
    "\n",
    "class T2(T):\n",
    "    def __init__(self,num):\n",
    "        super().__init__(b=_f) # pass _f to the superclass as method b\n",
    "        self.num=num\n",
    "        \n",
    "t = T2(num=3)\n",
    "test_eq(t.b(a=5), 15) # 3 * 5 = 15\n",
    "test_sig(T2, '(num)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "def _g(a=1): return a+1\n",
    "class T3(T): b = staticmethod(_g)\n",
    "t = T3()\n",
    "test_eq(t.b(2), 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "#test funcs_kwargs works with PrePostInitMeta\n",
    "class A(metaclass=PrePostInitMeta): pass\n",
    "\n",
    "@funcs_kwargs\n",
    "class B(A):\n",
    "    _methods = ['m1']\n",
    "    def __init__(self, **kwargs): pass\n",
    "    \n",
    "test_sig(B, '(*, m1=None)')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "import nbdev; nbdev.nbdev_export()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
